{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 316,
   "id": "df2d6b18-bd95-4c9e-a789-e8b79717b970",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib widget\n",
    "import seaborn as sns\n",
    "sns.set(style=\"whitegrid\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3459b50f-5e94-4d74-b35c-ce4f77028ea3",
   "metadata": {},
   "outputs": [],
   "source": [
    "original_df=pd.read_excel('附件1/附件1.xlsx',sheet_name='Data')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 143,
   "id": "4476b11e-37f8-4796-ad32-9b72a6a5df75",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>eventid</th>\n",
       "      <th>iyear</th>\n",
       "      <th>imonth</th>\n",
       "      <th>iday</th>\n",
       "      <th>approxdate</th>\n",
       "      <th>extended</th>\n",
       "      <th>resolution</th>\n",
       "      <th>country</th>\n",
       "      <th>country_txt</th>\n",
       "      <th>region</th>\n",
       "      <th>...</th>\n",
       "      <th>addnotes</th>\n",
       "      <th>scite1</th>\n",
       "      <th>scite2</th>\n",
       "      <th>scite3</th>\n",
       "      <th>dbsource</th>\n",
       "      <th>INT_LOG</th>\n",
       "      <th>INT_IDEO</th>\n",
       "      <th>INT_MISC</th>\n",
       "      <th>INT_ANY</th>\n",
       "      <th>related</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>199801010001</td>\n",
       "      <td>1998</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0</td>\n",
       "      <td>NaT</td>\n",
       "      <td>34</td>\n",
       "      <td>Burundi</td>\n",
       "      <td>11</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>“Burundi Rebels, Ex-Rwandan Army Soldiers Blam...</td>\n",
       "      <td>“Burundi--Attack Reported on Bujumbura Airport...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>CETIS</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>199801010002</td>\n",
       "      <td>1998</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0</td>\n",
       "      <td>NaT</td>\n",
       "      <td>167</td>\n",
       "      <td>Russia</td>\n",
       "      <td>9</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>“Bomb injures 3 in Moscow subway system,” The ...</td>\n",
       "      <td>“Bomb injures 3 in Moscow subway,” Charleston ...</td>\n",
       "      <td>“Bomb Injures 3 Workers in Moscow Metro,” Los ...</td>\n",
       "      <td>CETIS</td>\n",
       "      <td>-9</td>\n",
       "      <td>-9</td>\n",
       "      <td>0</td>\n",
       "      <td>-9</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>199801010003</td>\n",
       "      <td>1998</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0</td>\n",
       "      <td>NaT</td>\n",
       "      <td>603</td>\n",
       "      <td>United Kingdom</td>\n",
       "      <td>8</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>“Protestant gunmen kill Catholic in New Year's...</td>\n",
       "      <td>“Ulster Peace Shattered by Shooting: Catholic ...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>CETIS</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>199801020001</td>\n",
       "      <td>1998</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0</td>\n",
       "      <td>NaT</td>\n",
       "      <td>95</td>\n",
       "      <td>Iraq</td>\n",
       "      <td>10</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>“Iraq Condemns Attack on UNSCOM Baghdad Office...</td>\n",
       "      <td>Farouk Choukri , “Iraq, UN Officials Continue ...</td>\n",
       "      <td>“Iraqi Interior Minister on UNSCOM Attack, Kuw...</td>\n",
       "      <td>CETIS</td>\n",
       "      <td>-9</td>\n",
       "      <td>-9</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>199801020002</td>\n",
       "      <td>1998</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0</td>\n",
       "      <td>NaT</td>\n",
       "      <td>155</td>\n",
       "      <td>West Bank and Gaza Strip</td>\n",
       "      <td>10</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>“Woman Shot,” The Philadelphia Inquirer, Janua...</td>\n",
       "      <td>“Israeli Woman Critically Hurt by Gunfire in W...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>CETIS</td>\n",
       "      <td>-9</td>\n",
       "      <td>-9</td>\n",
       "      <td>0</td>\n",
       "      <td>-9</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 135 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "        eventid  iyear  imonth  iday approxdate  extended resolution  country  \\\n",
       "0  199801010001   1998       1     1        NaN         0        NaT       34   \n",
       "1  199801010002   1998       1     1        NaN         0        NaT      167   \n",
       "2  199801010003   1998       1     1        NaN         0        NaT      603   \n",
       "3  199801020001   1998       1     2        NaN         0        NaT       95   \n",
       "4  199801020002   1998       1     2        NaN         0        NaT      155   \n",
       "\n",
       "                country_txt  region  ... addnotes  \\\n",
       "0                   Burundi      11  ...      NaN   \n",
       "1                    Russia       9  ...      NaN   \n",
       "2            United Kingdom       8  ...      NaN   \n",
       "3                      Iraq      10  ...      NaN   \n",
       "4  West Bank and Gaza Strip      10  ...      NaN   \n",
       "\n",
       "                                              scite1  \\\n",
       "0  “Burundi Rebels, Ex-Rwandan Army Soldiers Blam...   \n",
       "1  “Bomb injures 3 in Moscow subway system,” The ...   \n",
       "2  “Protestant gunmen kill Catholic in New Year's...   \n",
       "3  “Iraq Condemns Attack on UNSCOM Baghdad Office...   \n",
       "4  “Woman Shot,” The Philadelphia Inquirer, Janua...   \n",
       "\n",
       "                                              scite2  \\\n",
       "0  “Burundi--Attack Reported on Bujumbura Airport...   \n",
       "1  “Bomb injures 3 in Moscow subway,” Charleston ...   \n",
       "2  “Ulster Peace Shattered by Shooting: Catholic ...   \n",
       "3  Farouk Choukri , “Iraq, UN Officials Continue ...   \n",
       "4  “Israeli Woman Critically Hurt by Gunfire in W...   \n",
       "\n",
       "                                              scite3  dbsource  INT_LOG  \\\n",
       "0                                                NaN     CETIS        0   \n",
       "1  “Bomb Injures 3 Workers in Moscow Metro,” Los ...     CETIS       -9   \n",
       "2                                                NaN     CETIS        0   \n",
       "3  “Iraqi Interior Minister on UNSCOM Attack, Kuw...     CETIS       -9   \n",
       "4                                                NaN     CETIS       -9   \n",
       "\n",
       "   INT_IDEO INT_MISC INT_ANY  related  \n",
       "0         1        0       1      NaN  \n",
       "1        -9        0      -9      NaN  \n",
       "2         0        1       1      NaN  \n",
       "3        -9        1       1      NaN  \n",
       "4        -9        0      -9      NaN  \n",
       "\n",
       "[5 rows x 135 columns]"
      ]
     },
     "execution_count": 143,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "original_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 144,
   "id": "8466c51b-024c-45e2-85b5-3be29aa52f82",
   "metadata": {},
   "outputs": [],
   "source": [
    "df=original_df.copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 157,
   "id": "d0fb1f2b-74f5-4196-938e-e3f906e4984a",
   "metadata": {},
   "outputs": [],
   "source": [
    "df[\"y-m\"]=df[\"iyear\"].map(str)+'-'+((df[\"imonth\"]-1)//3+1).map(str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 158,
   "id": "bc14a61f-6523-4099-98e6-f080889cd449",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0        2015-1\n",
       "1        2015-1\n",
       "2        2015-1\n",
       "3        2015-1\n",
       "4        2015-1\n",
       "          ...  \n",
       "39447    2017-4\n",
       "39448    2017-4\n",
       "39449    2017-4\n",
       "39450    2017-4\n",
       "39451    2017-4\n",
       "Name: y-m, Length: 39452, dtype: object"
      ]
     },
     "execution_count": 158,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df[\"y-m\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 186,
   "id": "246a2a4f-d100-4ec3-be38-50cb77190bcd",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_ym_num=df.groupby(\"y-m\").size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 378,
   "id": "c40da46b-9bd6-490b-8137-909f42ec947d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "y-m\n",
       "2015-1    4012\n",
       "2015-2    3761\n",
       "2015-3    3660\n",
       "2015-4    3532\n",
       "2016-1    3460\n",
       "2016-2    3629\n",
       "2016-3    3321\n",
       "2016-4    3177\n",
       "2017-1    2719\n",
       "2017-2    3023\n",
       "2017-3    2800\n",
       "2017-4    2358\n",
       "dtype: int64"
      ]
     },
     "execution_count": 378,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_ym_num"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 385,
   "id": "4ef2b298-abad-4492-935f-4961f7a4091d",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_ym_num.to_csv('ym_num.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 386,
   "id": "432312ce-79a1-4064-9b76-5ccf3a065a88",
   "metadata": {},
   "outputs": [],
   "source": [
    "class GrayForecast():\n",
    "    # 初始化\n",
    "    def __init__(self, data, n):\n",
    "        \"\"\"\n",
    "        :param data: Series/np/list\n",
    "        :param n: 预测数量\n",
    "        \"\"\"\n",
    "        plt.rcParams['font.sans-serif'] = ['SimHei']\n",
    "        plt.rcParams['axes.unicode_minus'] = False\n",
    "        if isinstance(data, pd.Series):\n",
    "            self.data = data.values\n",
    "        elif isinstance(data, np.ndarray):\n",
    "            self.data = data\n",
    "        elif isinstance(data, list):\n",
    "            self.data = np.array(data)\n",
    "        self.level_check()\n",
    "        self.GM_11_build_model(n)\n",
    "        print(\"返回值为dataframe，可通过.res_df拿到, 可通过.plot_res画预测图\\n\", self.res_df)\n",
    "\n",
    "    def level_check(self):\n",
    "        # 数据级比校验\n",
    "        b = self.data[0]\n",
    "        n = len(self.data)\n",
    "        lambda_k = np.zeros(n - 1)\n",
    "        while (True):\n",
    "            for i in range(n - 1):\n",
    "                lambda_k[i] = self.data[i] / self.data[i + 1]\n",
    "            if max(lambda_k) < np.exp(2 / (n + 2)) and min(lambda_k) > np.exp(-2 / (n + 1)):\n",
    "                self.c = self.data[0] - b\n",
    "                print(f\"完成数据 级比校验, 平移变换c={self.c}\")\n",
    "                break\n",
    "            else:\n",
    "                self.data = self.data + 0.1\n",
    "\n",
    "    # GM(1,1)建模\n",
    "    def GM_11_build_model(self, n):\n",
    "        '''\n",
    "            灰色预测\n",
    "            x：序列，numpy对象\n",
    "            n:需要往后预测的个数\n",
    "        '''\n",
    "        x = self.data\n",
    "        # 累加生成（1-AGO）序列\n",
    "        x1 = x.cumsum()\n",
    "        # 紧邻均值生成序列\n",
    "        z1 = (x1[:len(x1) - 1] + x1[1:]) / 2.0\n",
    "        z1 = z1.reshape((len(z1), 1))\n",
    "        B = np.append(-z1, np.ones_like(z1), axis=1)\n",
    "        Y = x[1:].reshape((len(x) - 1, 1))\n",
    "        # a为发展系数 b为灰色作用量\n",
    "        [[a], [b]] = np.dot(np.dot(np.linalg.inv(np.dot(B.T, B)), B.T), Y)  # 计算参数\n",
    "        # 预测数据\n",
    "        fit_res = [x[0]]\n",
    "        for index in range(1, len(x) + n):\n",
    "            fit_res.append((x[0] - b / a) * (1 - np.exp(a)) * np.exp(-a * (index)))\n",
    "        # 数据还原\n",
    "        self.data -= self.c\n",
    "        fit_res -= self.c\n",
    "        self.res_df = pd.concat([pd.DataFrame({'原始值': self.data}), pd.DataFrame({'预测值': fit_res})], axis=1)\n",
    "        print(f\"发展系数a={a}, 灰色作用量b={b}\\n\")\n",
    "        self.verfify(self.data, fit_res, a)\n",
    "        return self.res_df\n",
    "\n",
    "    def verfify(self, x, predict, a):\n",
    "        S1_2 = x.var()  # 原序列方差\n",
    "        e = list()  # 残差序列\n",
    "        for index in range(x.shape[0]):\n",
    "            e.append(x[index] - predict[index])\n",
    "        S2_2 = np.array(e).var()  # 残差方差\n",
    "        C = S2_2 / S1_2  # 后验差比\n",
    "        if C <= 0.35:\n",
    "            assess = '后验差比<=0.35，模型精度等级为好'\n",
    "        elif C <= 0.5:\n",
    "            assess = '后验差比<=0.5，模型精度等级为合格'\n",
    "        elif C <= 0.65:\n",
    "            assess = '后验差比<=0.65，模型精度等级为勉强'\n",
    "        else:\n",
    "            assess = '后验差比>0.65，模型精度等级为不合格'\n",
    "        print(f\"后验差比={C}, {assess} \\n\")\n",
    "\n",
    "        # 级比偏差\n",
    "        a_ = (1 - 0.5 * a) / (1 + 0.5 * a)\n",
    "        delta = [np.nan]\n",
    "        for i in range(x.shape[0] - 1):\n",
    "            delta.append(1 - a_ * (x[i] / x[i + 1]))\n",
    "\n",
    "        self.res_df = pd.concat([self.res_df, pd.DataFrame({'残差': e}),\n",
    "                                 pd.DataFrame({'相对误差': list(map(lambda x: '{:.2%}'.format(x), np.abs(e / x)))}),\n",
    "                                 pd.DataFrame({'级比偏差': delta})\n",
    "                                 ],\n",
    "                                axis=1)\n",
    "\n",
    "    def plot_res(self, xlabel='', ylabel=''):\n",
    "        res_df = self.res_df\n",
    "        f, ax = plt.subplots(figsize=(8, 5))\n",
    "        sns.lineplot(x=res_df.index.tolist(), y=res_df['预测值'], linewidth=2, ax=ax)\n",
    "        sns.scatterplot(x=res_df.index.tolist(), y=res_df['原始值'], s=60, color='r', marker='v', ax=ax)\n",
    "        plt.fill_between(np.where(np.isnan(res_df[\"原始值\"]))[0], y1=min(plt.yticks()[0]), y2=max(plt.yticks()[0]),\n",
    "                         color='orange', alpha=0.2)\n",
    "        ax.set_xlabel(xlabel, fontsize=15)\n",
    "        ax.set_ylabel(ylabel, fontsize=15)\n",
    "        plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 387,
   "id": "698d2c86-a81e-45fc-ab9c-8bb05b70da42",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "完成数据 级比校验, 平移变换c=0.5999999999999979\n",
      "发展系数a=0.03244780698095551, 灰色作用量b=4.6945697959855766\n",
      "\n",
      "后验差比=0.11672033342303606, 后验差比<=0.35，模型精度等级为好 \n",
      "\n",
      "返回值为dataframe，可通过.res_df拿到, 可通过.plot_res画预测图\n",
      "       原始值       预测值        残差    相对误差      级比偏差\n",
      "0   4.012  4.012000  0.000000   0.00%       NaN\n",
      "1   3.761  3.871975 -0.110975   2.95% -0.032677\n",
      "2   3.660  3.729198 -0.069198   1.89%  0.005215\n",
      "3   3.532  3.590980 -0.058980   1.67% -0.003153\n",
      "4   3.460  3.457174  0.002826   0.08%  0.011785\n",
      "5   3.629  3.327641  0.301359   8.30%  0.077012\n",
      "6   3.321  3.202243  0.118757   3.58% -0.057852\n",
      "7   3.177  3.080849  0.096151   3.03% -0.011949\n",
      "8   2.719  2.963330 -0.244330   8.99% -0.131136\n",
      "9   3.023  2.849564  0.173436   5.74%  0.129281\n",
      "10  2.800  2.739429  0.060571   2.16% -0.045170\n",
      "11  2.358  2.632811 -0.274811  11.65% -0.149532\n",
      "12    NaN  2.529597       NaN     NaN       NaN\n",
      "13    NaN  2.429679       NaN     NaN       NaN\n",
      "14    NaN  2.332950       NaN     NaN       NaN\n",
      "15    NaN  2.239310       NaN     NaN       NaN\n",
      "16    NaN  2.148659       NaN     NaN       NaN\n",
      "17    NaN  2.060902       NaN     NaN       NaN\n",
      "18    NaN  1.975948       NaN     NaN       NaN\n",
      "19    NaN  1.893705       NaN     NaN       NaN\n",
      "20    NaN  1.814089       NaN     NaN       NaN\n",
      "21    NaN  1.737014       NaN     NaN       NaN\n"
     ]
    }
   ],
   "source": [
    "gm_model=GrayForecast(df_ym_num/1000,10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 388,
   "id": "12874e49-b45d-4d7f-ad06-a6dbd6e2809d",
   "metadata": {},
   "outputs": [],
   "source": [
    "res_df=gm_model.res_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 389,
   "id": "14877d21-3ff3-47b1-823e-f054e0c71bc1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>原始值</th>\n",
       "      <th>预测值</th>\n",
       "      <th>残差</th>\n",
       "      <th>相对误差</th>\n",
       "      <th>级比偏差</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>4.012</td>\n",
       "      <td>4.012000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.00%</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>3.761</td>\n",
       "      <td>3.871975</td>\n",
       "      <td>-0.110975</td>\n",
       "      <td>2.95%</td>\n",
       "      <td>-0.032677</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3.660</td>\n",
       "      <td>3.729198</td>\n",
       "      <td>-0.069198</td>\n",
       "      <td>1.89%</td>\n",
       "      <td>0.005215</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3.532</td>\n",
       "      <td>3.590980</td>\n",
       "      <td>-0.058980</td>\n",
       "      <td>1.67%</td>\n",
       "      <td>-0.003153</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>3.460</td>\n",
       "      <td>3.457174</td>\n",
       "      <td>0.002826</td>\n",
       "      <td>0.08%</td>\n",
       "      <td>0.011785</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>3.629</td>\n",
       "      <td>3.327641</td>\n",
       "      <td>0.301359</td>\n",
       "      <td>8.30%</td>\n",
       "      <td>0.077012</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>3.321</td>\n",
       "      <td>3.202243</td>\n",
       "      <td>0.118757</td>\n",
       "      <td>3.58%</td>\n",
       "      <td>-0.057852</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>3.177</td>\n",
       "      <td>3.080849</td>\n",
       "      <td>0.096151</td>\n",
       "      <td>3.03%</td>\n",
       "      <td>-0.011949</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>2.719</td>\n",
       "      <td>2.963330</td>\n",
       "      <td>-0.244330</td>\n",
       "      <td>8.99%</td>\n",
       "      <td>-0.131136</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>3.023</td>\n",
       "      <td>2.849564</td>\n",
       "      <td>0.173436</td>\n",
       "      <td>5.74%</td>\n",
       "      <td>0.129281</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>2.800</td>\n",
       "      <td>2.739429</td>\n",
       "      <td>0.060571</td>\n",
       "      <td>2.16%</td>\n",
       "      <td>-0.045170</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>2.358</td>\n",
       "      <td>2.632811</td>\n",
       "      <td>-0.274811</td>\n",
       "      <td>11.65%</td>\n",
       "      <td>-0.149532</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>NaN</td>\n",
       "      <td>2.529597</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>NaN</td>\n",
       "      <td>2.429679</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>NaN</td>\n",
       "      <td>2.332950</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>NaN</td>\n",
       "      <td>2.239310</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>NaN</td>\n",
       "      <td>2.148659</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>NaN</td>\n",
       "      <td>2.060902</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>NaN</td>\n",
       "      <td>1.975948</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>NaN</td>\n",
       "      <td>1.893705</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>NaN</td>\n",
       "      <td>1.814089</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>NaN</td>\n",
       "      <td>1.737014</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      原始值       预测值        残差    相对误差      级比偏差\n",
       "0   4.012  4.012000  0.000000   0.00%       NaN\n",
       "1   3.761  3.871975 -0.110975   2.95% -0.032677\n",
       "2   3.660  3.729198 -0.069198   1.89%  0.005215\n",
       "3   3.532  3.590980 -0.058980   1.67% -0.003153\n",
       "4   3.460  3.457174  0.002826   0.08%  0.011785\n",
       "5   3.629  3.327641  0.301359   8.30%  0.077012\n",
       "6   3.321  3.202243  0.118757   3.58% -0.057852\n",
       "7   3.177  3.080849  0.096151   3.03% -0.011949\n",
       "8   2.719  2.963330 -0.244330   8.99% -0.131136\n",
       "9   3.023  2.849564  0.173436   5.74%  0.129281\n",
       "10  2.800  2.739429  0.060571   2.16% -0.045170\n",
       "11  2.358  2.632811 -0.274811  11.65% -0.149532\n",
       "12    NaN  2.529597       NaN     NaN       NaN\n",
       "13    NaN  2.429679       NaN     NaN       NaN\n",
       "14    NaN  2.332950       NaN     NaN       NaN\n",
       "15    NaN  2.239310       NaN     NaN       NaN\n",
       "16    NaN  2.148659       NaN     NaN       NaN\n",
       "17    NaN  2.060902       NaN     NaN       NaN\n",
       "18    NaN  1.975948       NaN     NaN       NaN\n",
       "19    NaN  1.893705       NaN     NaN       NaN\n",
       "20    NaN  1.814089       NaN     NaN       NaN\n",
       "21    NaN  1.737014       NaN     NaN       NaN"
      ]
     },
     "execution_count": 389,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 384,
   "id": "5c198b7c-6239-4601-ae5c-c306316d3c59",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6b81feb8c3f94da784bc6e6d199b2703",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "gm_model.plot_res(xlabel=\"2015-2017季度\",ylabel=\"案件数量\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "746566e1-419b-4ea7-9dc7-400780a4464e",
   "metadata": {},
   "source": [
    "## 线性拟合\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 285,
   "id": "ebb303ae-25ae-4c46-b959-0964db15ca1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.linear_model import LinearRegression\n",
    "from sklearn.preprocessing import PolynomialFeatures\n",
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.metrics import mean_squared_error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 286,
   "id": "ef866073-3198-4a44-b511-6ae210f7dba2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def polynomial_model(degree=1):\n",
    "    polynomial_features = PolynomialFeatures(degree=degree, include_bias=False)\n",
    "    linear_regression = LinearRegression(normalize=True)\n",
    "    pipeline = Pipeline([(\"polynomial_features\", polynomial_features),\n",
    "                         (\"linear_regression\", linear_regression)])\n",
    "    return pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 287,
   "id": "80615ee1-ae0c-4859-933a-9961cdd2c499",
   "metadata": {},
   "outputs": [],
   "source": [
    "X=np.arange(12).reshape(-1,1)\n",
    "Y=np.array(df_ym_num.values).reshape(-1,1)/1000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 288,
   "id": "6e87522e-d634-47c2-802c-d5afe2b9a2e4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[4.012],\n",
       "       [3.761],\n",
       "       [3.66 ],\n",
       "       [3.532],\n",
       "       [3.46 ],\n",
       "       [3.629],\n",
       "       [3.321],\n",
       "       [3.177],\n",
       "       [2.719],\n",
       "       [3.023],\n",
       "       [2.8  ],\n",
       "       [2.358]])"
      ]
     },
     "execution_count": 288,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 289,
   "id": "9d829377-8145-4156-94c2-e65d1451cf93",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'model': Pipeline(steps=[('polynomial_features',\n",
       "                   PolynomialFeatures(degree=1, include_bias=False)),\n",
       "                  ('linear_regression', LinearRegression(normalize=True))]),\n",
       "  'degree': 1,\n",
       "  'score': 0.8962555793414528,\n",
       "  'mse': 0.022496360916860938,\n",
       "  'RMSE': 0.14998786923235138},\n",
       " {'model': Pipeline(steps=[('polynomial_features',\n",
       "                   PolynomialFeatures(degree=3, include_bias=False)),\n",
       "                  ('linear_regression', LinearRegression(normalize=True))]),\n",
       "  'degree': 3,\n",
       "  'score': 0.9184471768135963,\n",
       "  'mse': 0.017684244921744915,\n",
       "  'RMSE': 0.13298212256444442},\n",
       " {'model': Pipeline(steps=[('polynomial_features',\n",
       "                   PolynomialFeatures(degree=5, include_bias=False)),\n",
       "                  ('linear_regression', LinearRegression(normalize=True))]),\n",
       "  'degree': 5,\n",
       "  'score': 0.9396800629966126,\n",
       "  'mse': 0.013080019770670167,\n",
       "  'RMSE': 0.11436791407851316},\n",
       " {'model': Pipeline(steps=[('polynomial_features',\n",
       "                   PolynomialFeatures(degree=10, include_bias=False)),\n",
       "                  ('linear_regression', LinearRegression(normalize=True))]),\n",
       "  'degree': 10,\n",
       "  'score': 0.9865833333202765,\n",
       "  'mse': 0.0029093244148683333,\n",
       "  'RMSE': 0.05393815361011474}]"
      ]
     },
     "execution_count": 289,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "degrees = [1, 3, 5, 10]\n",
    "results = []\n",
    "for d in degrees:\n",
    "    model = polynomial_model(degree=d)\n",
    "    model.fit(X, Y)\n",
    "    train_score = model.score(X, Y)\n",
    "    mse = mean_squared_error(Y, model.predict(X))\n",
    "    results.append({\"model\": model, \"degree\": d, \"score\": train_score, \"mse\": mse,'RMSE':np.sqrt(mse)})\n",
    "results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 310,
   "id": "5e06cc3d-c9ed-4ffe-8692-bad427984be9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f8b3bdc31ffa43d8888c301c26dcb0e1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from matplotlib.figure import SubplotParams\n",
    "plt.figure(figsize=(9, 6), dpi=200, subplotpars=SubplotParams(hspace=0.3))\n",
    "all_x=np.arange(22).reshape(-1,1)\n",
    "for i, r in enumerate(results):\n",
    "    fig = plt.subplot(2, 2, i + 1)\n",
    "    plt.xlim(0, 22)\n",
    "    plt.title(\"LinearRegression degree={}\".format(r[\"degree\"]))\n",
    "    plt.scatter(X, Y, s=5, c='b', alpha=0.5)\n",
    "    plt.plot(all_x, r[\"model\"].predict(all_x), 'r-')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 311,
   "id": "c1e4891a-909a-4b6b-b10b-edfb4cd6b615",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "灰色预测 {'mse': 0.025310297904259395, 'RMSE': 0.15909210509720273}\n"
     ]
    }
   ],
   "source": [
    "mse = mean_squared_error(res_df['原始值'][:12], res_df['预测值'][:12])\n",
    "print(\"灰色预测\",{\"mse\": mse,'RMSE':np.sqrt(mse)})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c4f60b7d-d0ac-4006-b915-23d7f4257c39",
   "metadata": {},
   "source": [
    "## 幂函数拟合"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 312,
   "id": "1dd975f7-2e5e-4489-9636-0d73a997cfb2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.optimize import curve_fit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 333,
   "id": "ff4488e7-2ad7-4d04-9900-59482f44920c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[4.700944628373652, 3.7009446283736516, 3.5019459261666057, 3.367654374534397, 3.263346740479471, 3.176851924048842, 3.1023313443550826, 3.036489460403104, 2.977266626492938, 2.923281727390908, 2.873559453844978, 2.8273831520939607]\n",
      "[4.012 3.761 3.66  3.532 3.46  3.629 3.321 3.177 2.719 3.023 2.8   2.358]\n",
      "12\n"
     ]
    }
   ],
   "source": [
    "def fund(x, a, b):\n",
    "    return -x**a + b\n",
    "xdata=np.arange(12)\n",
    "ydata=np.array(df_ym_num.values)/1000\n",
    "popt, pcov = curve_fit(fund, xdata, ydata)\n",
    "#popt数组中，三个值分别是待求参数a,b,c\n",
    "y2 = [fund(i, popt[0],popt[1]) for i in xdata]\n",
    "print(y2)\n",
    "print(ydata)\n",
    "print(len(y2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 334,
   "id": "5a5ea1c1-3896-434f-91a2-eaa2cf64ea85",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "806a671697fe4232b76bb35408d122bc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "%matplotlib widget\n",
    "plt.scatter(xdata, ydata, s=5, c='b', alpha=0.5)\n",
    "plt.plot(xdata,y2,'r-')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 337,
   "id": "284e99f8-353a-4f15-838d-a5bd584fe171",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "幂函数预测 {'mse': 0.09527403535238894, 'RMSE': 0.308664924072025}\n"
     ]
    }
   ],
   "source": [
    "mse = mean_squared_error(ydata, y2)\n",
    "print(\"幂函数预测\",{\"mse\": mse,'RMSE':np.sqrt(mse)})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65e96ce8-bd66-400b-845c-3a8ea7a37908",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
